# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py

import inspect
from typing import Callable, List, Optional, Union, Tuple
from dataclasses import dataclass

import math
import numpy as np
import torch
from tqdm import tqdm

from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging, BaseOutput

from einops import rearrange

from ..models.unet import UNet3DConditionModel

import os
from scipy.spatial.distance import cdist
import ImageReward as RM
import torch.nn.functional as F
from animatediff.pipelines.freeinit_utils import (
    get_freq_filter,
    freq_mix_3d,
    get_low_freq,
    image_noise_freq_mix_3d
)
from animatediff.utils.gaussian_smoothing import GaussianSmoothing, fn_smoothing_func
import torch.utils.checkpoint as checkpoint
from torch.optim.adam import Adam
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class AnimationPipelineOutput(BaseOutput):
    """Output container for the animation pipeline.

    Presumably returned by ``AnimationPipeline.__call__`` when ``return_dict`` is true; the
    return statement is outside the visible chunk, so confirm against the caller.
    """

    # Generated video frames, as a torch tensor or a numpy array.
    videos: Union[torch.Tensor, np.ndarray]
    # Mask accompanying the video; the annotation also admits a plain int (e.g. a placeholder).
    mask: Union[torch.Tensor, np.ndarray, int]


class AnimationPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        """Register the pipeline components and patch outdated configs in place.

        Besides registering ``vae``/``text_encoder``/``tokenizer``/``unet``/``scheduler``, this:
          * loads an ImageReward model into ``self.image_reward_model``;
          * forces ``scheduler.config.steps_offset`` to 1 and ``clip_sample`` to False when the
            scheduler config disagrees (emitting a deprecation notice first);
          * bumps ``unet.config.sample_size`` to 64 for UNets saved by diffusers < 0.9.0 whose
            sample size is below 64;
          * derives ``self.vae_scale_factor`` from the number of VAE ``block_out_channels``.
        """
        super().__init__()

        # NOTE(review): the checkpoint name and ``med_config`` are both empty strings; they look
        # like redacted placeholders — confirm which ImageReward checkpoint is intended.
        self.image_reward_model = RM.load("", med_config='')

        def _override_config(component, key, value):
            # Component configs are FrozenDicts, so build a fresh one with ``key`` replaced.
            patched = dict(component.config)
            patched[key] = value
            component._internal_dict = FrozenDict(patched)

        # Each check re-reads ``scheduler.config`` because the previous patch may have replaced it.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            _override_config(scheduler, "steps_offset", 1)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            _override_config(scheduler, "clip_sample", False)

        legacy_unet = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        small_sample_size = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if legacy_unet and small_sample_size:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            _override_config(unet, "sample_size", 64)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        # One spatial downsample per VAE block transition.
        downsample_steps = len(self.vae.config.block_out_channels) - 1
        self.vae_scale_factor = 2 ** downsample_steps

    def enable_vae_slicing(self):
        """Switch the VAE to sliced decoding via ``self.vae.enable_slicing()``."""
        vae = self.vae
        vae.enable_slicing()

    def disable_vae_slicing(self):
        """Undo :meth:`enable_vae_slicing` via ``self.vae.disable_slicing()``."""
        vae = self.vae
        vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        """Hand the UNet, text encoder and VAE to accelerate's ``cpu_offload``.

        Args:
            gpu_id: index of the CUDA device passed to ``cpu_offload`` as the execution device.

        Raises:
            ImportError: if ``accelerate`` is not installed.

        Note: ``self.image_reward_model`` is not included in the offloaded set.
        """
        if not is_accelerate_available():
            raise ImportError("Please install accelerate via `pip install accelerate`")
        from accelerate import cpu_offload

        execution_device = torch.device(f"cuda:{gpu_id}")
        candidates = (self.unet, self.text_encoder, self.vae)
        # Components may be absent (None); skip those.
        for model in filter(lambda component: component is not None, candidates):
            cpu_offload(model, execution_device)

    @property
    def _execution_device(self):
        """Device the models actually run on.

        Normally ``self.device``. When the pipeline reports the ``meta`` device and the UNet has an
        ``_hf_hook`` attribute (presumably attached by accelerate offloading — see
        ``enable_sequential_cpu_offload``), the first UNet submodule whose hook exposes a non-None
        ``execution_device`` decides. Falls back to ``self.device`` if none does.
        """
        offloaded = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
        if offloaded:
            for submodule in self.unet.modules():
                hook = getattr(submodule, "_hf_hook", None)
                hooked_device = getattr(hook, "execution_device", None)
                if hooked_device is not None:
                    return torch.device(hooked_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
        """Encode prompt text into text-encoder hidden states for the UNet.

        Args:
            prompt: a string or a list of strings.
            device: device the token ids (and optional attention mask) are moved to.
            num_videos_per_prompt: number of copies of each prompt's embedding to emit.
            do_classifier_free_guidance: when true, unconditional embeddings (from
                ``negative_prompt`` or empty strings) are prepended to the result.
            negative_prompt: ``None``, or a value of the same type as ``prompt``; a list must
                match ``prompt``'s length.

        Returns:
            Tensor with ``rows * num_videos_per_prompt`` entries along dim 0, where ``rows`` is
            the number of prompts; with guidance on, the unconditional block comes first,
            followed by the conditional block.

        Raises:
            TypeError: ``negative_prompt`` is not the same type as ``prompt``.
            ValueError: a list ``negative_prompt`` has a different length than ``prompt``.
        """
        n_prompts = len(prompt) if isinstance(prompt, list) else 1
        max_tokens = self.tokenizer.model_max_length
        wants_mask = hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask

        def embed(tokens, rows=None):
            # Run the text encoder, then tile each row num_videos_per_prompt times using
            # repeat + view (the mps-friendly way to duplicate along the batch axis).
            mask = tokens.attention_mask.to(device) if wants_mask else None
            hidden = self.text_encoder(tokens.input_ids.to(device), attention_mask=mask)[0]
            row_count = hidden.shape[0] if rows is None else rows
            token_count = hidden.shape[1]
            tiled = hidden.repeat(1, num_videos_per_prompt, 1)
            return tiled.view(row_count * num_videos_per_prompt, token_count, -1)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_tokens,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn (but continue) when the prompt exceeded the tokenizer's max length.
        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_tokens - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_tokens} tokens: {removed_text}"
            )

        cond_embeddings = embed(text_inputs)
        if not do_classifier_free_guidance:
            return cond_embeddings

        if negative_prompt is None:
            uncond_tokens = [""] * n_prompts
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif n_prompts != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {n_prompts}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_inputs = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=text_input_ids.shape[-1],
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = embed(uncond_inputs, n_prompts)

        # Stack unconditional and conditional embeddings so one UNet forward pass serves both
        # halves of classifier-free guidance.
        return torch.cat([uncond_embeddings, cond_embeddings])

    def decode_latents(self, latents):
        """Decode latent video(s) into pixel-space frames with the VAE.

        Frames are decoded one at a time (with a tqdm progress bar) rather than in one batched
        ``vae.decode`` call, presumably to bound peak VAE memory for long videos.

        Args:
            latents: 5-D latent tensor laid out as ``(b, c, f, h, w)`` (the layout the
                ``rearrange`` patterns below assume).

        Returns:
            float32 ``np.ndarray`` laid out as ``(b, c, f, H, W)`` with values clamped to [0, 1].
        """
        video_length = latents.shape[2]
        # Undo the latent scaling. Prefer the VAE's own config value (newer diffusers configs
        # carry ``scaling_factor``) and fall back to the SD-1.x constant used previously.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        # The result is returned as numpy, so no caller can backprop through it. Decoding
        # without autograd avoids holding every frame's activations, and keeps ``.numpy()``
        # from raising on a tensor that requires grad.
        with torch.no_grad():
            decoded_frames = [
                self.vae.decode(latents[frame_idx:frame_idx + 1]).sample
                for frame_idx in tqdm(range(latents.shape[0]))
            ]
        video = torch.cat(decoded_frames)
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        # Map from [-1, 1] to [0, 1].
        video = (video / 2 + 0.5).clamp(0, 1)
        # Always cast to float32: cheap, and numpy has no bfloat16 support.
        return video.cpu().float().numpy()

    def prepare_extra_step_kwargs(self, generator, eta):
        """Collect the optional keyword arguments accepted by ``self.scheduler.step``.

        Scheduler ``step`` signatures differ, so ``eta`` and ``generator`` are forwarded only
        when ``step`` declares a parameter with that name. ``eta`` is the η of the DDIM paper
        (https://arxiv.org/abs/2010.02502) and should lie in [0, 1]; schedulers without an
        ``eta`` parameter simply never receive it.

        Returns:
            dict containing the subset of ``{"eta": eta, "generator": generator}`` that
            ``step`` accepts, in that key order.
        """
        accepted = inspect.signature(self.scheduler.step).parameters
        candidates = (("eta", eta), ("generator", generator))
        return {name: value for name, value in candidates if name in accepted}

    def check_inputs(self, prompt, height, width, callback_steps):
        """Validate the user-facing generation arguments.

        Raises:
            ValueError: ``prompt`` is neither ``str`` nor ``list``; ``height`` or ``width`` is not
                divisible by 8; or ``callback_steps`` is not a positive ``int`` (``None`` included).
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if any(dim % 8 != 0 for dim in (height, width)):
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        steps_ok = isinstance(callback_steps, int) and callback_steps > 0
        if not steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Create (or validate) the initial latent video noise.

        Args:
            batch_size: number of latent videos to produce.
            num_channels_latents: latent channel count.
            video_length: number of frames.
            height, width: pixel resolution; divided by ``self.vae_scale_factor`` for the latents.
            dtype: dtype of freshly sampled noise.
            device: target device (noise is sampled on CPU when this is an ``mps`` device).
            generator: ``None``, one ``torch.Generator``, or a list with exactly one generator
                per sample.
            latents: optional pre-made latents; must already have the expected shape.

        Returns:
            Tensor of shape ``(batch_size, num_channels_latents, video_length,
            height // vae_scale_factor, width // vae_scale_factor)`` on ``device``, multiplied by
            ``self.scheduler.init_noise_sigma``.

        Raises:
            ValueError: generator list length differs from ``batch_size``, or the supplied
                ``latents`` have the wrong shape.
        """
        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                # One generator per sample: each draws a single-sample slice, so concatenating
                # them yields exactly `batch_size` rows (drawing the full `shape` per generator
                # would produce batch_size ** 2 rows).
                sample_shape = (1,) + shape[1:]
                latents = torch.cat(
                    [
                        torch.randn(sample_shape, generator=sample_generator, device=rand_device, dtype=dtype)
                        for sample_generator in generator
                    ],
                    dim=0,
                ).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents * self.scheduler.init_noise_sigma

    @torch.no_grad()
    def init_filter(self, video_length, height, width, filter_params):
        """Build the frequency filter used for noise reinitialization and cache it.

        The filter is created by ``get_freq_filter`` for a single-sample latent of shape
        ``[1, unet.in_channels, video_length, height // vae_scale_factor, width // vae_scale_factor]``
        on the execution device, and stored in ``self.freq_filter``.

        Args:
            video_length: number of frames.
            height, width: pixel resolution of the video.
            filter_params: object exposing ``method``, ``d_s``, ``d_t`` and, when ``method`` is
                ``"butterworth"``, ``n`` (the order is passed as ``None`` otherwise).
        """
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        latent_shape = [1, self.unet.in_channels, video_length, latent_height, latent_width]

        method = filter_params.method
        butterworth_order = filter_params.n if method == "butterworth" else None
        self.freq_filter = get_freq_filter(
            latent_shape,
            device=self._execution_device,
            filter_type=method,
            n=butterworth_order,
            d_s=filter_params.d_s,
            d_t=filter_params.d_t,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        is_main_process: bool = True,
        savedir: str = '',
        scale_factor: int = 20,
        scale_range: Tuple[float, float] = (1., 0.5),
        **kwargs,
    ):
        
        use_fp16 = True
        if use_fp16:
            print('Warning: using half percision for inferencing!')
            self.vae.to(dtype=torch.float16)
            self.unet.to(dtype=torch.float16)
            self.text_encoder.to(dtype=torch.float16)

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        # batch_size = 1 if isinstance(prompt, str) else len(prompt)
        batch_size = 1
        if latents is not None:
            batch_size = latents.shape[0]
        if isinstance(prompt, list):
            batch_size = len(prompt)

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
        if negative_prompt is not None:
            negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))
        # print('[scale_range]', scale_range)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        number_denoising = 2
        with self.progress_bar(total=num_inference_steps * number_denoising) as progress_bar:
            
            optimal_frame_idx, optimal_frame_reward = 0, 0
            for j in range(number_denoising):
                
                if j == 0:
                    # init_latents = latents.clone()

                    init_latents = self.prepare_latents(
                        batch_size * num_videos_per_prompt,
                        num_channels_latents,
                        video_length,
                        height,
                        width,
                        text_embeddings.dtype,
                        device,
                        generator,
                        latents=None,
                    )

                    print('[1 latents]', latents.size())
                    latents = rearrange(latents, "b c f h w -> (b f) c h w").contiguous().unsqueeze(2)
                    # latents = torch.randn_like(latents)
                    print('[2 latents]', latents.size())

                    # print('[text_embeddings]', text_embeddings.size())
                    frame_text_embeddings = torch.cat([text_embeddings[0:1]] * video_length + [text_embeddings[1:2]] * video_length, dim=0)


                    num_inference_steps = 25
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
                    

                    for i, t in enumerate(timesteps):

                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=frame_text_embeddings).sample.to(dtype=latents_dtype)

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                            progress_bar.update()
                            if callback is not None and i % callback_steps == 0:
                                callback(i, t, latents)

                 
                    video_frame = self.decode_latents(latents)      # tensor
                    print('[video_frame]', video_frame.shape, video_frame.min(), video_frame.max())
                    
                    video_frame_pil_list = []
                    for i in range(video_length):
                        frame_pil = video_frame[i:i + 1]
                        frame_pil = np.squeeze(frame_pil, axis=2)
                        frame_pil = frame_pil.transpose(0, 2, 3, 1)
                        frame_pil = self.numpy_to_pil(frame_pil)

                        video_frame_pil_list = video_frame_pil_list + frame_pil

                    with torch.no_grad():
                        rewards = self.image_reward_model.score(prompt, video_frame_pil_list)
                        for frame_idx, frame_reward in enumerate(rewards):
                            if frame_reward > optimal_frame_reward:
                                optimal_frame_idx, optimal_frame_reward = frame_idx, frame_reward
                        


                elif j == 1:
                    latents = latents[optimal_frame_idx:optimal_frame_idx + 1]
                    if j == 1: latents = torch.cat([latents] * (video_length // 1), dim=2)

                    ucond_text_embeddings_list, text_embeddings_list = [], []
                    for k in range(video_length):

                        augmented_prompt = prompt[0]
                        augmented_prompt = [augmented_prompt]

           
                        print(k, augmented_prompt)
                        _text_embeddings = self._encode_prompt(
                            augmented_prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
                        )
                        ucond_text_embeddings_list.append(_text_embeddings[0:1])
                        text_embeddings_list.append(_text_embeddings[1:2])
                    frame_text_embeddings = torch.cat(ucond_text_embeddings_list + text_embeddings_list, dim=0)
                    print('[frame_text_embeddings]', frame_text_embeddings.size())


                    # ***************
                    # SDS
                    # ***************
                    if True:
                    
                        print('[1 latents]', latents.mean().item())
                        num_inference_steps = 25
            
                        # self.scheduler.set_timesteps(num_inference_steps, device=device)
                        strength = 1.
                        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=strength, device='cuda')

                        # with torch.enable_grad():
                        if True:
                            adam_latents = latents.clone().detach()

                            _noise = self.prepare_latents(
                                    batch_size * num_videos_per_prompt,
                                    num_channels_latents,
                                    video_length,
                                    height,
                                    width,
                                    text_embeddings.dtype,
                                    device,
                                    generator,
                                    latents=None,)
                            
                            # for sds_i, t in enumerate(timesteps):
                            for sds_i in range(0, 15, 1):
          
                                t = timesteps[sds_i % len(timesteps)]

                              
                                
                                # _noise = init_latents
                                noise_latents = self.scheduler.add_noise(adam_latents, _noise, t)

                                latent_model_input = torch.cat([noise_latents] * 2) if do_classifier_free_guidance else noise_latents
                                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=frame_text_embeddings).sample.to(dtype=latents_dtype)
                                
                                # perform guidance
                                if do_classifier_free_guidance:
                                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


                                w = (1 - self.scheduler.alphas_cumprod[t])          # sds
                                # w = 1                                               # uniform
                                # w = (self.scheduler.alphas_cumprod[t] ** 0.5 * (1 - self.scheduler.alphas_cumprod[t]))  # fantasia3d


                                grad = w * (noise_pred - _noise)

                                step_size = 1.
                                adam_latents = adam_latents - step_size * grad
                                
                                # adam_latents = torch.cat([latents[:, :, :1, :, :], adam_latents[:, :, 1:, :, :]], dim=2)
                                # adam_latents = adam_latents * 0.8 + latents * 0.2


                                adam_latents_mean   = adam_latents.mean()
                                grad_mean           = grad.mean()
                                print(sds_i, t.item(), len(timesteps), '\t', f'{adam_latents_mean:0.8f}', '[adam_latents]', grad.size(), f'{grad_mean:0.8f}', f'{w:0.8f}')
                                # optimizer.step()

                        print('[2 latents]', latents.mean().item())

                        latents = adam_latents.clone()



                    num_inference_steps = 25
           
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
                    strength = 0
                    print('[self.scheduler.timesteps]', self.scheduler.timesteps)

                    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=strength, device='cuda')

                    # start_timestep = torch.tensor([999]).cuda()
                    # timesteps = torch.cat([start_timestep, timesteps], dim=0)
                    # print('[video timesteps]', timesteps)

                    # latent_timestep = timesteps[:1].repeat(1 * 1)
                    latent_timestep = torch.tensor([999]).cuda()
                    print('[latent_timestep]', latent_timestep)
                  
                    # latents = self.scheduler.add_noise(latents, init_latents, latent_timestep)

                    video_latents = latents.clone()



                    alpha = 1
                    # noise = torch.randn_like(latents)
                    noise = init_latents
                    latents = noise * (1 - alpha) + latents * alpha

           



                    # mask
                    _B, _C, _L, _H, _W = latents.size()

                    c_frame = video_length // 2 - 1
                    mask = self.fn_get_mask(_B, _C, _L, _H, _W, ratio=1. / 4, optimal_frame_idx=c_frame, decay=0.).to(dtype=latents_dtype)
                    latents = latents * mask + init_latents * (1 - mask)

       
       
                    



                    ucond_text_embeddings_list, text_embeddings_list = [], []
                    for k in range(video_length):
                        # augmented_prompt = augmented_prefix[k] + ' ' + prompt[0] 
                        # augmented_prompt = ""
                        augmented_prompt = prompt[0]
                        # augmented_prompt = 'Highly dynamic, significant movement. ' + prompt[0]
                        
                    
                        augmented_prompt = [augmented_prompt]

                       
                        print(k, augmented_prompt)
                        _text_embeddings = self._encode_prompt(
                            augmented_prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
                        )
                        ucond_text_embeddings_list.append(_text_embeddings[0:1])
                        text_embeddings_list.append(_text_embeddings[1:2])
                        # text_embeddings_list.append(prompt_embeds)
                    video_frame_text_embeddings = torch.cat(
                        ucond_text_embeddings_list + text_embeddings_list,
                        dim=0)
                    print('[video_frame_text_embeddings]', video_frame_text_embeddings.size())


                    start, end = -1, -1
                    length_timesteps, i, repeat_i = len(timesteps), 0, 0
                    # for i, t in enumerate(timesteps):


                    while i < length_timesteps:

                        # if i > 1 and i < 4: 
                        #     i = i + 1
                        #     continue
                      
                        t = timesteps[i]
                        

                        # expand the latents if we are doing classifier free guidance
                        if i == 0: latent_model_input = torch.cat([latents, latents] * 1) if do_classifier_free_guidance else latents
                        else: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                        
                        
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        if repeat_i == 0: frame_text_embeddings = video_frame_text_embeddings
                        else: frame_text_embeddings = video_frame_text_embeddings
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=frame_text_embeddings).sample.to(dtype=latents_dtype)

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                            progress_bar.update()
                            if callback is not None and i % callback_steps == 0:
                                callback(i, t, latents)


                       
                    
                        i = i + 1
                        
                        repeat_mark = -1
                        if i < repeat_mark:  # 
           
                            if repeat_i < 1:
                           
                                if i <= 5: t0, tMax = timesteps[i] + 1, timesteps[i - 1] + 1
                                else: t0, tMax = timesteps[i] + 1, timesteps[i - 1] + 1

                                print('[t0, tMax]', t0, tMax)

                                # eps
                                eps = self.prepare_latents(
                                    batch_size * num_videos_per_prompt,
                                    num_channels_latents,
                                    video_length,
                                    height,
                                    width,
                                    text_embeddings.dtype,
                                    device,
                                    generator,
                                    latents=None,
                                )



                                # latents = self.DDPM_forward(latents, t0, tMax, generator, eps=None)
                                latents = self.DDPM_forward(latents, t0, tMax, generator, eps=eps)
                                # latents = self.DDPM_forward(latents, t, timesteps[0] + 1, generator)

                                print('[latents]', latents.size())

                                


                                
                                i = i - 1
                                repeat_i = repeat_i + 1

                    
                            else:
                                repeat_i = 0
                                

                    # latents = latents.squeeze(2)
                    # latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length).contiguous()
                
                else:
                    # num_inference_steps = 25
                    if j == 2: num_inference_steps = 10
                    elif j == 3: num_inference_steps = 15
                    elif j == 4: num_inference_steps = 20
                    else: num_inference_steps = 25

                    self.scheduler.set_timesteps(num_inference_steps, device=device)
                    strength = 1
                    print('[self.scheduler.timesteps]', self.scheduler.timesteps)

                    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=strength, device='cuda')
                    latent_timestep = timesteps[:1].repeat(1 * 1)

                    # eps
                    eps = self.prepare_latents(
                        batch_size * num_videos_per_prompt,
                        num_channels_latents,
                        video_length,
                        height,
                        width,
                        text_embeddings.dtype,
                        device,
                        generator,
                        latents=None,
                    )


                    latents = self.scheduler.add_noise(latents, eps, latent_timestep)

                    # z_rand = torch.randn_like(latents)
                    # latents = freq_mix_3d(latents.to(dtype=torch.float32), z_rand.to(dtype=torch.float32), LPF=self.freq_filter)
                    # latents = latents.to(latents_dtype)

                    ucond_text_embeddings_list, text_embeddings_list = [], []
                    for k in range(video_length):
                        # augmented_prompt = augmented_prefix[k] + prompt[0]
                        augmented_prompt = prompt[0]
                        augmented_prompt = [augmented_prompt]
                        print(k, augmented_prompt)
                        _text_embeddings = self._encode_prompt(
                            augmented_prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
                        )
                        ucond_text_embeddings_list.append(_text_embeddings[0:1])
                        text_embeddings_list.append(_text_embeddings[1:2])
                    frame_text_embeddings = torch.cat(
                        ucond_text_embeddings_list + text_embeddings_list,
                        dim=0)
                    print('[frame_text_embeddings]', frame_text_embeddings.size())

                    for i, t in enumerate(timesteps):

                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=frame_text_embeddings).sample.to(dtype=latents_dtype)

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                            progress_bar.update()
                            if callback is not None and i % callback_steps == 0:
                                callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)
        
        if output_type == "pil":
            video = video[:, :, 0:1, :, :]
            video = np.squeeze(video, axis=2)
            video = video.transpose(0, 2, 3, 1)
            print('[video]', video.shape)
            video = self.numpy_to_pil(video)

        if not return_dict:
            return video
        
        mask = rearrange(mask, "b c f h w -> (b f) c h w").contiguous()
        mask = F.interpolate(mask, scale_factor=8, mode='nearest')
        mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length).contiguous()

        mask = mask.cpu()

        return AnimationPipelineOutput(videos=video, mask=mask)


    def get_timesteps(self, num_inference_steps, strength, device):
        """Return the trailing slice of ``self.scheduler.timesteps`` selected by ``strength``.

        ``strength`` chooses how many of the ``num_inference_steps`` steps are run.
        The earliest (noisiest) entries of the scheduler's timestep list are
        skipped so that only the last ``int(num_inference_steps * strength)``
        remain. That count is capped at ``num_inference_steps``.

        Args:
            num_inference_steps: Number of denoising steps. Presumably the same
                value previously given to ``self.scheduler.set_timesteps``;
                verify at call sites.
            strength: Fraction of steps to keep (1 keeps all, 0 keeps none).
            device: Unused; retained for interface compatibility.

        Returns:
            A tuple ``(timesteps, n)``. ``timesteps`` is
            ``self.scheduler.timesteps`` with the skipped prefix removed.
            ``n`` is ``num_inference_steps`` minus the number of skipped steps.
        """
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        skipped = max(num_inference_steps - steps_to_run, 0)
        kept = self.scheduler.timesteps[skipped:]
        return kept, num_inference_steps - skipped
    
    def DDPM_forward(self, x0, t0, tMax, generator, shape=None, eps=None):
        """Noise ``x0`` forward in closed form using ``self.scheduler.alphas[t0:tMax]``.

        Computes ``sqrt(a) * x0 + sqrt(1 - a) * eps`` where ``a`` is the product
        of ``self.scheduler.alphas[t0:tMax]``. An empty range gives ``a == 1``,
        so ``x0`` is returned unchanged.

        Args:
            x0: Tensor to noise, or ``None`` to return pure Gaussian noise of
                ``shape``.
            t0: Start index (inclusive) into ``self.scheduler.alphas``. May be
                an int or a 0-dim integer tensor.
            tMax: End index (exclusive) into ``self.scheduler.alphas``. May be
                an int or a 0-dim integer tensor.
            generator: ``torch.Generator`` (or ``None``) used to draw noise.
            shape: Shape of the returned noise. Only used, and required, when
                ``x0`` is ``None``.
            eps: Optional pre-drawn noise to use instead of sampling. It must
                be broadcastable with ``x0``.

        Returns:
            The noised tensor, or fresh standard-normal noise of ``shape`` when
            ``x0`` is ``None``. Fresh noise is created on the generator's device
            (CPU when ``generator`` is not a ``torch.Generator``) with the
            default dtype.

        Raises:
            ValueError: If both ``x0`` and ``shape`` are ``None``.
        """
        if x0 is None:
            # Handle this before touching x0.device / x0.dtype. The previous
            # ordering made this branch raise AttributeError unconditionally.
            if shape is None:
                raise ValueError("`shape` must be provided when `x0` is None.")
            noise_device = generator.device if isinstance(generator, torch.Generator) else "cpu"
            return torch.randn(shape, generator=generator, device=noise_device)

        device = x0.device
        if eps is None:
            # Sample on CPU when x0 lives on MPS, then move the noise onto x0's
            # device so the combination below does not mix devices.
            rand_device = "cpu" if device.type == "mps" else device
            eps = torch.randn(
                x0.shape, dtype=x0.dtype, generator=generator, device=rand_device
            ).to(device)

        # int() accepts both Python ints and the 0-dim tensors built from
        # scheduler timesteps at the call site.
        alpha_prod = torch.prod(self.scheduler.alphas[int(t0):int(tMax)])
        return torch.sqrt(alpha_prod) * x0 + torch.sqrt(1 - alpha_prod) * eps
    
   
    
